#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Aug 24 14:31:17 2018
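For several candidate numbers of species groups, this script searches for a species grouping by simulated annealing and saves the resulting cost ratios.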
@author: Lu Shen
"""
import os
import numpy as np
import glob
import pickle
import copy
import random
import re
import pandas as pd
# Working directory for the whole script; later relative paths (e.g. './data') resolve against it.
os.chdir("/n/home03/lshen/lshen3/GC_speedup/multiple_tests_standard_new/PL_data")

#------------------------------------------------------------------------------
#---- define a function -------------------------------------------------------
def save_obj(obj, name ):
    """Pickle *obj* into the file at path *name*, overwriting it.

    The highest pickle protocol available to this interpreter is used.
    """
    protocol = pickle.HIGHEST_PROTOCOL
    with open(name, mode='wb') as handle:
        pickle.dump(obj, handle, protocol)

def load_obj(name):
    """Return the object unpickled from the file at path *name*.

    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(name, mode='rb') as handle:
        obj = pickle.load(handle)
    return obj

def optimize(Vdot, group_num=15, input_groups=None, n_iter=20000):
    """Search for a species grouping that minimises the cost function Z1.

    Every cell of the first three axes of ``Vdot`` is flagged per species
    where ``|Vdot| >= 100``.  The cost of a group is (number of cells in
    which any member is flagged) * (group size); Z1 is the sum over groups.

    Simulated annealing moves one species at a time between two groups.
    The first ``group_num - 3`` groups hold the non-halogen species.  The
    last three hold the I, Br and Cl families, which are only exchanged
    among themselves, once every 5th iteration.  An improving move is always
    kept.  Otherwise it is kept with probability ``exp(-deltaE / T)``, where
    ``deltaE`` is the cost increase divided by the current total cost.
    ``T`` starts at 0.01 and is multiplied by 0.999 per iteration.

    To reduce computation when testing for the best number of groups, the
    data are subsampled with stride 2 over the first 72 / 46 / 59 indices
    of the first three axes.

    Parameters
    ----------
    Vdot : array-like, 4-D
        The last axis indexes species.  It presumably has 234 entries: the
        halogen indices below go up to 233.
    group_num : int
        Total number of groups including the three halogen groups (>= 5).
    input_groups : sequence of sequences, optional
        Initial grouping (``group_num`` lists of species indices).  When
        None or empty, an initial grouping is generated.  The caller's
        lists are copied and never modified.
    n_iter : int
        Number of annealing iterations (20000 by default).

    Returns
    -------
    list of lists
        The final grouping.

    Raises
    ------
    ValueError
        If ``group_num < 5``; at least two non-halogen groups are needed
        to exchange species.
    """
    if group_num < 5:
        raise ValueError('group_num must be at least 5 '
                         '(two or more non-halogen groups plus three halogen groups)')

    # Same cells as indexing with range(0,72,2), range(0,46,2), range(0,59,2).
    Vdot = np.asarray(Vdot)[:72:2, :46:2, :59:2, :]
    NVAR = Vdot.shape[-1]

    # One row per grid cell (C order: first axis slowest), one column per species.
    # The value is 1 where |Vdot| >= 100.  NaN also maps to 1, as with the original
    # "ones, then zero where |Vdot| < 100" mask.
    spdata = np.where(np.abs(Vdot) < 100, 0.0, 1.0).reshape(-1, NVAR)
    total0 = np.sum(spdata == 1)

    # Non-organic halogen species are treated as fast in the stratosphere, so
    # they get three dedicated groups.  The listed indices are 1-based.
    family_I = [x - 1 for x in (4, 7, 8, 36, 39, 46, 60, 70, 84, 87, 111, 137, 148, 179, 183)]
    family_Br = [x - 1 for x in (50, 58, 73, 107, 116, 121, 124, 145, 186, 209, 224, 227)]
    family_Cl = [x - 1 for x in (74, 85, 95, 97, 150, 196, 198, 225, 233, 234)]
    halogen = set(family_I + family_Br + family_Cl)

    if input_groups is None or len(input_groups) == 0:
        print('generate groups')
        # Fraction of cells in which each species is flagged (NaN counts as unflagged here).
        frac_flagged = (np.abs(Vdot) >= 100).mean(axis=(0, 1, 2))
        # Species ordered by ascending fraction (ties broken by index), then the
        # non-halogen ones are split into group_num-3 contiguous, near-equal chunks.
        order = [x for _, x in sorted(zip(frac_flagged, np.arange(NVAR)))]
        non_halogen = [x for x in order if x not in halogen]
        groups = [list(chunk) for chunk in np.array_split(np.array(non_halogen), group_num - 3)]
        groups += [list(family_I), list(family_Br), list(family_Cl)]
    else:
        print('use input groups')
        # Copy so the moves below never modify the caller's lists.
        groups = [list(g) for g in input_groups]

    def cal_time(GG):
        """Return a length-group_num array with the cost of each group in GG (others 0)."""
        total = np.zeros(group_num)
        for igroup in GG:
            ind = groups[igroup]
            n_active = np.count_nonzero(spdata[:, ind].sum(axis=1) > 0)
            total[igroup] = n_active * len(ind)
        return total

    total_time1 = cal_time(range(group_num))
    iteration = 1
    temperature = 0.01

    def try_move(candidates):
        """Propose moving one random species between two distinct groups drawn
        from ``candidates``, then accept it or undo it in place.

        Returns False, without proposing anything, if the donor group is empty.
        """
        nonlocal total_time1
        g1, g2 = np.random.choice(candidates, 2, replace=False)
        donor, receiver = groups[g1], groups[g2]
        if len(donor) == 0:
            return False
        ind = np.random.choice(len(donor), 1, replace=False)[0]
        item = donor[ind]
        pos = donor.index(item)  # first occurrence, exactly like list.remove()
        del donor[pos]
        receiver.append(item)

        new_time = cal_time([g1, g2])[[g1, g2]]
        new_E = sum(new_time)
        prev_E = sum(total_time1[[g1, g2]])
        total_time2 = total_time1.copy()
        total_time2[[g1, g2]] = new_time

        if new_E < prev_E:
            print('Success: ', iteration, 'ratio: ', sum(total_time2) / total0, 'T:', temperature)
            total_time1 = total_time2
            return True
        # Uphill or flat move: accept with probability exp(-deltaE / T).
        deltaE = (new_E - prev_E) / sum(total_time1)
        accept_prob = np.exp(-deltaE / temperature)
        if random.random() < accept_prob:
            print('Perturb: ', iteration, 'ratio: ', sum(total_time2) / total0, 'random:', accept_prob)
            total_time1 = total_time2
        else:
            # Undo the move; this restores the previous grouping, including order.
            receiver.pop()
            donor.insert(pos, item)
        return True

    halogen_candidates = [group_num - 3, group_num - 2, group_num - 1]
    while iteration <= n_iter:
        # Non-halogen move.  An empty donor group draws again without
        # consuming an iteration.
        if not try_move(group_num - 3):
            continue
        temperature *= 0.999  # cooling schedule
        iteration += 1
        # Report here: a top-of-loop check would almost never see a multiple
        # of 100, because multiples of 5 are immediately advanced past by the
        # halogen move below.
        if iteration % 100 == 0:
            print(iteration)

        # Every 5th iteration, also try a move among the three halogen groups.
        if iteration % 5 == 0 and try_move(halogen_candidates):
            temperature *= 0.999
            iteration += 1

    return groups

def cal_final_ratio(Vdot,groups, group_num=10, regime_num=20):
    """Calcuate the cost function Z2"""
    #--- define functions------
    def cal_time_ratio(GG):
        total=np.zeros(group_num)
        for igroup in GG:
            ind= groups[igroup]
            ap = spdata[:,ind]
            ap2= ap.sum(axis=1)
            total[igroup]=sum(ap2>0)*len(ind)
        total0=spdata.sum()
        return total.sum()/total0 
    
    #--- prepare spdata ------
    slow_species_ind=np.ones(Vdot.shape)
    slow_species_ind[abs(Vdot)<100]=0#0 is for slow; 1 is for fast
    spdata=np.zeros((72*46*59,234))
    index=0
    for i in range(72):
        for j in range(46):
            for k in range(59):
                spdata[index,:]=slow_species_ind[i,j,k,:]
                index=index+1
                
    #--- calculate the ratio ----
    ratio1=cal_time_ratio(range(group_num))
    print('optimal ratio',ratio1)
    
    #--- find most frequent regimes -----
    labels=np.zeros((spdata.shape[0],group_num),dtype=int)
    for igroup in range(group_num):
        ap=spdata[:,groups[igroup]]
        ap2=ap.sum(axis=1)
        ap2[ap2>=1]=1
        labels[:,igroup]=ap2
        
    result=np.zeros(spdata.shape[0],dtype=int)#result is the group label
    for k in range(spdata.shape[0]):
        ind=list(labels[k,:])
        ap=''.join(map(str,ind))
        result[k]=int(ap, 2)
    
    full_chem=2**group_num-1
    y2=pd.value_counts(pd.Series(result))
    most_freq_lables=[]
    num=0
    for k in range(min([len(y2),regime_num])):
        num=num+y2.iloc[k]
        print(y2.index[k],y2.iloc[k])
        most_freq_lables.append(y2.index[k])
    print(num/result.shape[0])
    
    #-in case the full-chem is not in the most freq ones --
    most_freq_lables.append(full_chem)

    precode='{0:0'+str(group_num)+'b}'    
    group_species_num=[len(x) for x in groups]
    binary_groups=[]
    regime_species_num=[]
    for k in most_freq_lables:
        ap=precode.format(k)
        ap2=[int(x) for x in list(ap)]
        regime_species_num.append(sum(np.array(ap2)*np.array(group_species_num)))
        binary_groups.append(ap2)
    print(binary_groups)
    
    full_chem_num=np.argmax(most_freq_lables)+1
    
    #---- calculate all_types -------
    final_num=most_freq_lables
    all_binary=binary_groups
    
    all_types=np.ones(2**group_num,dtype=int)
    for num in range(2**group_num):
        ap=precode.format(num)
        ap2=np.array([int(x) for x in list(ap)])
        diff=np.ones(len(final_num),dtype=int)*99
        for kk in range(len(final_num)):
            temp=all_binary[kk]-ap2
            if (sum(temp==-1)==0):
                diff[kk]=sum(temp)
        if sum(diff<99)==0:#if there no matched groups, then use 13
            all_types[num]=full_chem_num
        else:
            all_types[num]=np.argmin(diff)+1
    
    #=======================================
    N=spdata.shape[0]
    labels_3D=np.zeros(N,dtype=int)
    num_fast_species=np.zeros(N,dtype=int)
    for i in range(N):
        ap=spdata[i,:]
        temp=np.zeros(group_num,dtype=int)
        for igroup in range(group_num):
            temp[igroup]=sum(ap[groups[igroup]])
        temp[temp>=1]=1
        ind=list(temp)
        label_num=int(''.join(map(str,ind)),2)
        ind=all_types[label_num]
        labels_3D[i]=ind
        num_fast_species[i]=regime_species_num[ind-1]
    
    ratio2=num_fast_species.sum()/spdata.sum()
    print('refined ratio',ratio2)
    
    return [ratio1,ratio2]

#------------------------------------------------------------------------------
# Relative to the working directory set near the top of the script.
os.chdir('./data')
# Input pickles in name order.  x[17:19] presumably selects the month field of
# the file name (only "01" is kept) -- TODO confirm against the naming scheme.
files = sorted(glob.glob("*.pkl"))
files = [x for x in files if x[17:19] == "01"]

#-----------------------------------------------------------------------------
all_group_num = [5, 7, 9, 10, 11, 12, 13, 14, 15, 17, 20]
all_result = {}

# NOTE(review): {run_num} looks like a template placeholder substituted before
# the script is run (a 1-based index into `files`).  As written it is a set
# literal over an undefined name, so this line fails until it is replaced.
filename=files[{run_num}-1]
# Third '_'/'.'-separated field of the file name (the date string).
strdate = re.split(r'_|\.', filename)[2]

# Vdot is only read by optimize/cal_final_ratio, so load it once.
Vdot = load_obj(filename)
for group_num in all_group_num:
    print("==================")
    print(group_num)
    groups = optimize(Vdot, group_num=group_num, input_groups=[])
    ratio = cal_final_ratio(Vdot, groups, group_num=group_num, regime_num=20)
    all_result[str(group_num)] = ratio
    # Saved after every group count, so results so far are on disk if the run stops early.
    save_obj(all_result,'/n/home03/lshen/lshen3/GC_speedup/multiple_tests_standard_new/PL_data_Halogen/Step1_find_group_num/results_'+strdate+'.pkl')
